import torch

torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn as nn
import transformers
from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed
import os
import hydra
import torch.distributed as dist
import torch.multiprocessing as mp
from omegaconf import OmegaConf, DictConfig
import mask_trainers as trainers
import wandb
import json
import socket
from typing import Optional, Set
from huggingface_hub import login
from peft import LoraConfig, PeftModel, get_peft_model
from collections import defaultdict

# Set torch.distributed's debug level to OFF.
dist.set_debug_level(dist.DebugLevel.OFF)
# Register `get_local_run_dir` as an OmegaConf resolver so configs can interpolate it;
# the resolver forwards (exp_name, local_dirs) to utils.get_local_run_dir.
OmegaConf.register_new_resolver("get_local_run_dir",
                                lambda exp_name, local_dirs: get_local_run_dir(exp_name, local_dirs))
# Global record: the key is a layer index and the value is a flat list holding every
# expert id that MoEHook saw selected by that layer's sparse-MoE block. This is plain
# module state, so each worker process accumulates its own copy.
moe_records = defaultdict(list)


class MoEHook:
    """Forward hook that logs which experts a sparse-MoE block selected.

    One instance is bound to one decoder layer. Each time the hooked module
    finishes a forward pass, the hook reads ``module.selected_experts``,
    flattens it, and appends the resulting expert ids to
    ``moe_records[layer_idx]``.

    NOTE(review): assumes the hooked module stores its routing choice on a
    ``selected_experts`` tensor attribute (presumably shaped
    [batch, seq_len, k]) -- confirm against the model implementation in use.
    """

    def __init__(self, layer_idx):
        # Index of the decoder layer this hook reports for; used as the key
        # into the global `moe_records` mapping.
        self.layer_idx = layer_idx

    def __call__(self, module, inputs, output):
        """Record the expert ids chosen in this forward pass; returns None so the output is untouched."""
        # Pure bookkeeping: keep it out of the autograd graph.
        with torch.no_grad():
            chosen = module.selected_experts.reshape(-1).cpu().tolist()
            moe_records[self.layer_idx] += chosen


def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module,
                reference_model: Optional[nn.Module] = None):
    """Run training on one worker process, then report expert usage and clean up.

    Steps, in order:
      1. Initialise the distributed backend when the trainer name contains 'FSDP'.
      2. Start a wandb run on rank 0 if wandb is enabled in the config.
      3. Build the trainer class named by ``config.trainer`` and call ``train()``.
      4. Print how often each expert id appears in this process's ``moe_records``,
         summed over all layers.
      5. Finish the wandb run (rank 0) and destroy the process group if one exists.

    Args:
        rank: Index of this worker process.
        world_size: Total number of worker processes.
        config: Resolved run configuration.
        policy: Model to train.
        reference_model: Optional reference model handed to the trainer.
    """
    from collections import Counter

    if 'FSDP' in config.trainer:
        init_distributed(rank, world_size, port=config.fsdp_port)

    # Only rank 0 talks to wandb.
    if rank == 0 and config.wandb.enabled:
        os.environ['WANDB_CACHE_DIR'] = get_local_dir(config.local_dirs)
        wandb_kwargs = dict(
            entity=config.wandb.entity,
            project=config.wandb.project,
            config=OmegaConf.to_container(config),
            dir=get_local_dir(config.local_dirs),
            name=config.exp_name,
        )
        wandb.init(**wandb_kwargs)

    # The trainer class is looked up by name in the trainers module.
    trainer_cls = getattr(trainers, config.trainer)
    print(f'Creating trainer on process {rank} with world size {world_size}')
    trainer = trainer_cls(
        policy,
        config,
        config.seed,
        config.local_run_dir,
        reference_model=reference_model,
        rank=rank,
        world_size=world_size,
    )
    trainer.train()

    # ===== Global statistics: merge the per-layer selections into one tally =====
    usage = Counter()
    for selections in moe_records.values():
        usage.update(selections)
    # Total number of expert selections recorded across every layer.
    n_selections = sum(len(selections) for selections in moe_records.values())

    print(f"\nrank, {rank}, Overall expert distribution across ALL layers:")
    for expert_id in sorted(usage):
        hits = usage[expert_id]
        print(f"rank, {rank}, Expert {expert_id} selected {hits} times ({hits / n_selections:.2%})")

    # Close wandb (rank 0 only, and only if a run is actually open).
    if rank == 0 and config.wandb.enabled:
        if wandb.run is not None:
            wandb.finish()  # flush & close background threads

    # Destroy the distributed process group to release NCCL resources.
    if dist.is_initialized():
        dist.destroy_process_group()


@hydra.main(version_base=None, config_path="../config", config_name="config")
def main(config: DictConfig):
    """Main entry point for training.

    Validates the config, builds the policy model (with expert-selection hooks
    and a LoRA adapter), optionally builds a reference model, and starts the
    worker process(es).

    Args:
        config: Hydra/OmegaConf run configuration.

    Raises:
        ValueError: If ``OmegaConf.missing_keys`` reports any missing key.
    """

    # Resolve hydra references, e.g. so we don't re-compute the run directory
    OmegaConf.resolve(config)

    missing_keys: Set[str] = OmegaConf.missing_keys(config)
    if missing_keys:
        raise ValueError(f"Got missing keys in config:\n{missing_keys}")

    print(OmegaConf.to_yaml(config))

    # Persist the fully resolved config next to the run outputs.
    config_path = os.path.join(config.local_run_dir, 'config.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    print('=' * 140)
    print(f'Writing to {socket.gethostname()}:{config.local_run_dir}')
    print('=' * 140)

    os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs)

    # Only the single-process BasicTrainer asks HF to spread the model across devices.
    model_kwargs = {'device_map': 'balanced'} if config.trainer == 'BasicTrainer' else {}
    policy_dtype = getattr(torch, config.model.policy_dtype)

    load_path = config.model.name_or_path
    print('building policy from path', load_path)

    policy = transformers.AutoModelForCausalLM.from_pretrained(load_path, low_cpu_mem_usage=True,
                                                               use_cache=False, torch_dtype=policy_dtype,
                                                               **model_kwargs)
    # Insert hooks that record expert selections. The block is matched by class
    # name, so the model class does not have to be imported here.
    for idx, layer in enumerate(policy.model.layers):
        module = getattr(layer, "mlp", None)
        if module is not None and type(module).__name__ == "OlmoeSparseMoeBlock":
            module.register_forward_hook(MoEHook(idx))
            print(f"[Hook] Registered on model.layers[{idx}].mlp")

    tokenizer = transformers.AutoTokenizer.from_pretrained(load_path)
    if tokenizer.pad_token_id is None:
        # NOTE(review): only the policy's embeddings are resized here; the
        # reference model built below is not. Confirm the trainer never feeds
        # the new pad id to the reference model.
        tokenizer.add_special_tokens({'pad_token': '<PAD>'})
        policy.config.pad_token_id = tokenizer.pad_token_id
        policy.resize_token_embeddings(len(tokenizer))

    if config.model.archive is None:
        # Fresh adapter: attach LoRA to the modules named in target_modules.
        peft_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=['gate_proj', 'up_proj', 'down_proj', 'gate']
        )
        policy = get_peft_model(policy, peft_config)
        # The LoRI-style variant (freezing lora_A, optionally only for the
        # high-frequency experts 1/2/3/7) is currently disabled.
    else:
        # Resume from a saved adapter. Previously this branch left `policy` as
        # the bare base model, so the archive was ignored and the call to
        # print_trainable_parameters() below failed because that method only
        # exists on PEFT models. is_trainable=True keeps the adapter weights
        # trainable instead of loading them frozen for inference.
        print('loading from archive', config.model.archive)
        policy = PeftModel.from_pretrained(policy, config.model.archive, is_trainable=True)

    # Print the trainable parameters
    policy.print_trainable_parameters()

    disable_dropout(policy)

    # Only these losses are given a reference model; every other loss trains without one.
    if config.loss.name in ['dpo', 'soft_sft']:
        print('building reference model')
        reference_model_dtype = getattr(torch, config.model.reference_dtype)
        reference_model = transformers.AutoModelForCausalLM.from_pretrained(load_path, use_cache=False,
                                                                            low_cpu_mem_usage=True,
                                                                            torch_dtype=reference_model_dtype,
                                                                            **model_kwargs)
        disable_dropout(reference_model)
    else:
        reference_model = None

    if 'FSDP' in config.trainer:
        # One worker process per visible CUDA device.
        world_size = torch.cuda.device_count()
        print('starting', world_size, 'processes for FSDP training')
        mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
        print("finish lora/lori adapter training.")
    else:
        print('starting single-process worker')
        worker_main(0, 1, config, policy, reference_model)


# Script entry point: main() is wrapped by @hydra.main, so it is called without arguments
# and hydra supplies the `config` object.
if __name__ == '__main__':
    main()